import pandas as pd 
import numpy as np 
import statsmodels.api as sm 
import os 
import sys 
import yaml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import KFold
from pprint import pprint
import random
random.seed(50)  # NOTE(review): seeds only the stdlib `random` module, which nothing in this chunk uses; pandas/scikit-learn randomness is controlled by the random_state arguments below.
from statistics import mean
from matplotlib import ticker
from matplotlib.ticker import FormatStrFormatter
import math
import pickle

def _fold_csv_path(datapath, split, fold, eitc, dep_database):
    """Return the CSV path for one side ('train' or 'test') of a CV fold."""
    return (datapath + split + '_data_' + str(fold)
            + '_eitc_' + str(eitc)
            + '_dep_database_' + str(dep_database) + '.csv')


def k_fold_split(random_state=50,
                 n_splits=5,
                 eitc=True,
                 dep_database=True,
                 datapath='/REDACTED/fairness/code/rf/data/'):
    """Shuffle the modelling data once and write K train/test CSV pairs.

    Files are written to ``datapath`` as
    ``{train,test}_data_{fold}_eitc_{eitc}_dep_database_{dep_database}.csv``.

    Parameters
    ----------
    random_state : int
        Seed for the single row shuffle (``DataFrame.sample``).
    n_splits : int
        Number of folds passed to ``KFold``.
    eitc : bool
        If True, keep only rows whose ``activity_code`` is 270 or 271.
    dep_database : bool
        If True, read ``clean_rf_data_plus_dep_database.csv`` (requires
        ``eitc=True``); otherwise read ``clean_rf_data.csv``.
    datapath : str
        Directory, with trailing slash, for the input and output CSVs.
        NOTE(review): the function ``os.chdir``s first, so a relative
        ``datapath`` is resolved against the config directory.

    Returns
    -------
    generator of (train_indices, test_indices)
        ``KFold.split`` over the shuffled frame, i.e. the same positional
        splits that were written to disk.

    Raises
    ------
    ValueError
        If ``dep_database`` is requested without ``eitc``.
    """
    # Validate before any side effects. In the original this case was an
    # unreachable `elif` (the `dep_database == True` branch matched first).
    if dep_database and not eitc:
        raise ValueError('dep_database features not available for full population.')

    # load config (not used below; loading kept so the chdir side effect and
    # message are unchanged). `with` closes the handle the original leaked.
    os.chdir('/REDACTED/fairness/code/rf/config')
    with open('data-config.yaml', 'r') as stream:
        yaml.safe_load(stream)
    print('config file loaded')

    if dep_database:
        df = pd.read_csv(datapath + 'clean_rf_data_plus_dep_database.csv')
    else:
        # Always read from datapath; the original read this file relative to
        # the cwd (the config directory) in the full-population case.
        df = pd.read_csv(datapath + 'clean_rf_data.csv')
    if eitc:
        # Restrict to activity codes 270/271 (presumably the EITC population).
        df = df.loc[(df.activity_code == 270) | (df.activity_code == 271)]

    # Shuffle once here and leave KFold unshuffled, so the splits returned
    # below line up with the files written in the loop. reset_index() without
    # drop keeps the pre-shuffle row label as an 'index' column in the CSVs.
    kf = KFold(n_splits=n_splits)
    df = df.sample(frac=1, random_state=random_state).reset_index()

    for fold, (train, test) in enumerate(kf.split(df)):
        df.loc[train].to_csv(_fold_csv_path(datapath, 'train', fold, eitc, dep_database))
        df.loc[test].to_csv(_fold_csv_path(datapath, 'test', fold, eitc, dep_database))

    return kf.split(df)

def tune_model(train_fold=0, # options are 0, 1, 2, 3, or 4
                eitc=True,
                dep_database=True,
                datapath='/REDACTED/fairness/code/rf/data/',
                model_type='reg', ## options are 'reg', 'cls'
                threshold=None, ## threshold for classifier model
                rs=50, ## random state
                ts=0.25, ## test size
                cvs=5,
                njs=10):
    """Run a randomized hyperparameter search for a random forest on one fold.

    Reads ``train_data_{train_fold}_eitc_{eitc}_dep_database_{dep_database}.csv``
    from ``datapath``. It selects the feature columns listed in the YAML config
    (``features_plus_dep_database_str`` or ``features_str``) that are present,
    then fits ``RandomizedSearchCV`` with 1000 sampled candidates. The best
    parameters are pickled to
    ``{eitc_dep_database|eitc|fullpop}_{reg|cls_<threshold>}_best_params_train_fold_{train_fold}.pickle``
    in ``datapath``.

    Parameters
    ----------
    train_fold : int
        Index of the training-fold CSV to read.
    eitc, dep_database : bool
        Select the input file, the feature list and the output filename prefix.
        ``dep_database`` requires ``eitc``.
    datapath : str
        Directory, with trailing slash, for the input CSV and the output pickle.
    model_type : {'reg', 'cls'}
        'reg' regresses ``chg_in_tax_owed_pv``. 'cls' classifies
        ``chg_in_tax_owed_pv >= threshold`` (NaN counts as 0).
    threshold : float or None
        Required when ``model_type == 'cls'``.
    rs : int
        Random state for the forest and for candidate sampling.
    ts : float
        Unused; kept for backward compatibility.
    cvs : int
        Number of cross-validation folds in the search.
    njs : int
        ``n_jobs`` for the search.

    Returns
    -------
    dict
        ``best_params_`` of the fitted search.

    Raises
    ------
    ValueError
        For an unknown ``model_type``, a missing classifier ``threshold``, or
        ``dep_database`` requested without ``eitc``.
    """
    # Validate up front. Previously a bad model_type left `labels`/`rf_random`
    # unbound, and threshold=None failed inside the label comprehension.
    if model_type not in ('reg', 'cls'):
        raise ValueError("model_type must be 'reg' or 'cls', got %r" % (model_type,))
    if model_type == 'cls' and threshold is None:
        raise ValueError("threshold is required when model_type='cls'")
    if dep_database and not eitc:
        raise ValueError('dep_database features not available for full population.')

    # load config (`with` closes the handle the original leaked)
    os.chdir('/REDACTED/fairness/code/rf/config')
    with open('data-config.yaml', 'r') as stream:
        out = yaml.safe_load(stream)
    print('config file loaded')

    # load data
    df = pd.read_csv(datapath + 'train_data_' + str(train_fold)
                     + '_eitc_' + str(eitc)
                     + '_dep_database_' + str(dep_database) + '.csv')

    # define features and labels
    feature_key = 'features_plus_dep_database_str' if dep_database else 'features_str'
    feature_vars = [x for x in out[feature_key] if x in df.columns]
    features = df[feature_vars]

    if model_type == 'reg':
        labels = df['chg_in_tax_owed_pv']
        model_tag = 'reg'
        rf = RandomForestRegressor(random_state=rs)
        # 'auto' meant 1.0 (all features) for regressors; it was removed in
        # scikit-learn 1.3, where it would make half the candidates fail.
        max_features = [1.0, 'sqrt']
    else:
        # Vectorized form of the original `1 if x >= threshold else 0`.
        labels = (df['chg_in_tax_owed_pv'] >= threshold).astype(int).rename('tc_' + str(threshold))
        model_tag = 'cls_' + str(threshold)
        rf = RandomForestClassifier(random_state=rs)
        # For classifiers 'auto' was an alias of 'sqrt' (removed in 1.3).
        max_features = ['sqrt']

    # define grid for hyperparam search
    random_grid = {
        'n_estimators': [int(x) for x in np.linspace(start=1000, stop=2000, num=50)],
        'max_features': max_features,
        'max_depth': [int(x) for x in np.linspace(10, 110, num=20)] + [None],
        'min_samples_split': [2, 5, 10, 20, 30, 40, 50],
        'min_samples_leaf': [1, 2, 4, 5, 10],
        'bootstrap': [True, False],
    }

    # tune
    rf_random = RandomizedSearchCV(estimator=rf,
                                   param_distributions=random_grid,
                                   cv=cvs,
                                   verbose=2,
                                   n_jobs=njs,
                                   n_iter=1000,
                                   random_state=rs)
    rf_random.fit(features, labels)

    # write out best params; dep_database implies eitc (validated above)
    if dep_database:
        population = 'eitc_dep_database'
    elif eitc:
        population = 'eitc'
    else:
        population = 'fullpop'
    out_path = (datapath + population + '_' + model_tag
                + '_best_params_train_fold_' + str(train_fold) + '.pickle')
    with open(out_path, 'wb') as handle:
        pickle.dump(rf_random.best_params_, handle, protocol=pickle.HIGHEST_PROTOCOL)

    return rf_random.best_params_
